using LinearAlgebra
using SparseArrays

using CSV
using DataFrames
using KrylovKit

for line in ("", "Starting XY8...", "")
    println(line)
end
# The simulated tau is half of the experimental tau, chosen for convenience:
# free evolution between consecutive pulses runs over 2 * tau here, and the
# exported time axis is 2 * tau_arr.
tau_arr = [
    3.6099999999999997e-5, 3.6899999999999996e-5, 3.77e-5, 3.85e-5,
    3.929999999999999e-5, 4.01e-5, 4.09e-5, 4.1700000000000004e-5,
    4.2499999999999996e-5, 4.3299999999999995e-5, 4.41e-5, 4.4899999999999994e-5,
    4.57e-5, 4.65e-5, 4.73e-5, 4.81e-5,
    4.8899999999999996e-5, 4.97e-5, 5.05e-5, 5.129999999999999e-5,
    5.21e-5,
]
# Number of XY8 cycles applied for every tau value.
number_of_xy8s = 6

"""
    core_unitary(U_x, U_y, U_tau, U_2tau)

Return the propagator of one XY8 cycle, without the `U_x_2` factors that
`full_unitary` and the time-evolution loop apply at both ends.

With `X = U_x` and `Y = U_y`, the factors are multiplied left to right in the
order

    U_tau · X · U_2tau · Y · U_2tau · X · U_2tau · Y · U_2tau ·
            Y · U_2tau · X · U_2tau · Y · U_2tau · X · U_tau

i.e. the pulse pattern X Y X Y Y X Y X, with `U_2tau` between consecutive
pulses and `U_tau` at both ends. All four matrices share the element type `T`.
"""
function core_unitary(
    U_x::Matrix{T},
    U_y::Matrix{T},
    U_tau::Matrix{T},
    U_2tau::Matrix{T},
) where T <: Number
    X, Y, gap = U_x, U_y, U_2tau
    factors = (
        U_tau,
        X, gap, Y, gap, X, gap, Y, gap,
        Y, gap, X, gap, Y, gap, X,
        U_tau,
    )
    # A left fold gives the same association as the n-ary chain
    # `a * b * c * ...`, i.e. ((a * b) * c) * ...
    return foldl(*, factors)
end

"""
    ggms(d::Integer) -> Vector

Return `d^2` Hermitian sparse `d × d` matrices: the `d^2 - 1` generalized
Gell-Mann matrices plus the identity. They are normalized so that
`tr(G_i' * G_j) == δ_ij` (an orthonormal basis under the Hilbert–Schmidt
inner product).

Ordering of the returned vector (`E_ab` has a single `1` at row `a`, column `b`):

1. `(d^2 - d) ÷ 2` real symmetric "R" generators `(E_ab + E_ba) / √2` for
   `a < b`, with `a` in the outer loop.
2. `(d^2 - d) ÷ 2` imaginary antisymmetric "L" generators
   `-i (E_ba - E_ab) / √2` for `b < a`, with `a` in the outer loop.
3. `d - 1` diagonal (Cartan) generators. The `m`-th one carries the values
   `(1, …, 1, -m, 0, …, 0) / √(m(m + 1))` written onto the diagonal starting
   from the bottom-right corner, i.e. in reverse row order.
4. The identity divided by `√d`.

The generators have different sparse element types (Float64 for 1, 3 and 4;
ComplexF64 for 2), so the result is a `Vector{Any}`.

Throws `ArgumentError` if `d < 1`.
"""
function ggms(d::Integer)::Vector
    d >= 1 || throw(ArgumentError("dimension d must be at least 1, got $d"))

    basis = Any[]
    sizehint!(basis, d^2)

    # "R" generators: 1/√2 at (b, a) and (a, b).
    for a in 1:d, b in (a + 1):d
        push!(basis, sparse([b, a], [a, b], [1.0, 1.0], d, d) / sqrt(2))
    end

    # "L" generators: -i/√2 at (b, a) (upper triangle) and +i/√2 at (a, b).
    for a in 1:d, b in 1:(a - 1)
        push!(basis, sparse([b, a], [a, b], [-1.0im, 1.0im], d, d) / sqrt(2))
    end

    # Cartan generators; entry k of `diagvals` lands at (d - k + 1, d - k + 1).
    # Zero entries are kept as stored values, as `sparse(I, J, V)` retains them.
    for m in 1:(d - 1)
        diagvals = sqrt(1 / (m * (m + 1))) .* [ones(m); -m; zeros(d - m - 1)]
        push!(basis, sparse(d:-1:1, d:-1:1, diagvals, d, d))
    end

    push!(basis, sparse(I, d, d) / sqrt(d))

    return basis
end


"""
    full_unitary(U_x_2, U_x, U_y, U_tau, U_2tau)

Return the propagator of one complete XY8 cycle including the `U_x_2` factors
on both ends. With `X = U_x` and `Y = U_y`, the factors are multiplied left to
right in the order

    U_x_2 · U_tau · X · U_2tau · Y · U_2tau · X · U_2tau · Y · U_2tau ·
    Y · U_2tau · X · U_2tau · Y · U_2tau · X · U_tau · U_x_2

Any numeric matrices are accepted (e.g. real-valued propagators as well as
`Matrix{ComplexF64}`); the result's element type follows Julia's usual
promotion rules for `*`.
"""
function full_unitary(
    U_x_2::AbstractMatrix{<:Number},
    U_x::AbstractMatrix{<:Number},
    U_y::AbstractMatrix{<:Number},
    U_tau::AbstractMatrix{<:Number},
    U_2tau::AbstractMatrix{<:Number},
)
    factors = (
        U_x_2, U_tau,
        U_x, U_2tau, U_y, U_2tau, U_x, U_2tau, U_y, U_2tau,
        U_y, U_2tau, U_x, U_2tau, U_y, U_2tau, U_x,
        U_tau, U_x_2,
    )
    # Left fold: same association as the n-ary chain `a * b * c * ...`.
    return foldl(*, factors)
end

p0_arr = zeros(length(tau_arr))
p1_arr = zeros(length(tau_arr))

# Propagators are built as exp(L * dtau)^(n - 1) * exp(L * tau_1), which is only
# valid on a uniform tau grid, so check that tau_arr is one.
dtau = tau_arr[2] - tau_arr[1]
@assert isapprox(tau_arr, tau_arr[1]:dtau:(tau_arr[end]+dtau/2))
# NOTE(review): liouville_free, evol_x, evol_y, evol_x_2, H_x, num_lvls,
# partition_sum, beta_w_units, energies_N0, GG, p0_op_vec and p1_op_vec are
# globals defined outside this file — confirm their definitions before reuse.
@time evol_tau1_xy8 = exp(liouville_free * tau_arr[1])
@time evol_dtau_xy8 = exp(liouville_free * dtau)

println("Computing time evolution...")
flush(stdout)

# Initial density matrix: diagonal, with weights exp(-beta * E_n) / Z on the
# first num_lvls levels. It does not depend on tau, so it (and its expansion
# below) is built once instead of once per tau value.
rho0 = zeros(ComplexF64, size(H_x)...)
for ho_level in 1:num_lvls
    weight = 1 / partition_sum * exp(-beta_w_units * energies_N0[ho_level])
    rho0[ho_level, ho_level] += weight
end
@assert tr(rho0) ≈ 1

# Vector of tr(G * rho0) for every operator G in GG
# (presumably GG = ggms(d) — confirm against the caller).
psi_init = tr.(GG .* Ref(rho0))

@time for (idx_tau, tau) in enumerate(tau_arr)
    @assert tau ≈ tau_arr[1] + (idx_tau-1) * dtau
    evol_tau = evol_dtau_xy8^(idx_tau-1) * evol_tau1_xy8
    evol_2tau = evol_tau^2
    evol_core = core_unitary(evol_x, evol_y, evol_tau, evol_2tau)

    # evol_x_2, then number_of_xy8s XY8 cycles, then evol_x_2 again.
    psit = evol_x_2 * psi_init
    psit = evol_core^number_of_xy8s * psit
    psit = evol_x_2 * psit

    # NOTE(review): p0_arr / p1_arr hold Float64, so these assignments throw an
    # InexactError if the projections are complex with a nonzero imaginary part.
    p0_arr[idx_tau] = p0_op_vec' * psit
    p1_arr[idx_tau] = p1_op_vec' * psit
end

"""
    sum_chi2(data_df, theory)

Return the χ² of `theory` against the measured values `data_df.mean`, using
asymmetric uncertainties. A point where the residual `mean - theory` is negative
(the model lies above the data) is weighted by `data_df.errorUpper`. Every other
point is weighted by `data_df.errorLower`.

`data_df` can be any object exposing `mean`, `errorUpper` and `errorLower`
properties (e.g. a `DataFrame`).
"""
function sum_chi2(data_df, theory)
    residuals = data_df.mean - theory
    sigmas = map(residuals, data_df.errorUpper, data_df.errorLower) do r, upper, lower
        r < 0 ? upper : lower
    end
    return sum(residuals .^ 2 ./ sigmas .^ 2)
end

flush(stdout)

# Export on the experiment's time axis. The experimental tau is twice the
# simulated tau, hence time = 2 * tau_arr.
output_df = DataFrame(
    time=2tau_arr,
    p0=p0_arr,
    p1=p1_arr,
)

# NOTE(review): save_dir and pstring are defined outside this file.
mkpath(save_dir)
# joinpath uses the platform's path separator and does not produce a doubled
# "/" when save_dir already ends with one.
ofn = joinpath(save_dir, "xy8-$pstring.csv")
CSV.write(ofn, output_df)
